Is Flash Attention Stable?

llm
research paper
Author

Santosh Sawant

Published

May 10, 2024

Given the size and complexity of these workloads, training Large Language Models (LLMs) often takes months and spans hundreds or thousands of GPUs. For example, LLaMA2’s 70B-parameter model took 1,720,320 GPU hours to train. With training jobs this long, training instability has become an increasingly serious problem. As reported in works such as Google’s PaLM paper, instability often shows up as loss spikes, which can occur up to 20 times over a single training run. These loss spikes are costly because they often interrupt training, forcing the job to be stopped and restarted.

One under-explored potential cause of training instability is numeric deviation. Numeric deviation between an optimization and its corresponding baseline can lead to a gradual accumulation of errors, which over the course of training can culminate in loss spikes that require resetting the model state.
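Floating-point addition is not associative, so an optimization that merely reorders a reduction (as tiled kernels do) can produce slightly different numbers from its baseline even at the same precision. The NumPy sketch below is a toy illustration of that effect, not code from the paper: it sums the same float32 values left to right and in blocks of 64, then reports how far each order lands from a float64 reference. The helper names `sequential_sum` and `blocked_sum` and the block size are hypothetical choices.

```python
import numpy as np

def sequential_sum(x):
    """Left-to-right float32 accumulation, one element at a time (the 'baseline' order)."""
    acc = np.float32(0.0)
    for value in x:
        acc = np.float32(acc + value)
    return acc

def blocked_sum(x, block=64):
    """Sum fixed-size blocks first, then combine the partial sums (a 'tiled' order)."""
    partials = [sequential_sum(x[i:i + block]) for i in range(0, len(x), block)]
    return sequential_sum(np.array(partials, dtype=np.float32))

rng = np.random.default_rng(0)
for n in (1_000, 10_000, 100_000):
    x = rng.standard_normal(n).astype(np.float32)
    exact = np.sum(x, dtype=np.float64)  # float64 reference for the same inputs
    seq, blk = sequential_sum(x), blocked_sum(x)
    print(f"n={n:>7}: |seq - exact|={abs(seq - exact):.2e}  "
          f"|blocked - exact|={abs(blk - exact):.2e}  |seq - blocked|={abs(seq - blk):.2e}")
```

Both orders use exactly the same precision; any gap between them comes only from which intermediate results get rounded.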

To understand this numeric deviation in training optimizations, researchers from Meta developed a principled quantitative approach consisting of two phases: (1) a numerical microbenchmark of the Flash Attention operation, which allows experimentation with different numerical precisions as well as testing of various optimizations throughout the algorithm. This framework allows the Attention Matrix output to be compared directly between Baseline Attention, Flash Attention, and a numeric re-implementation. (2) A data-driven analysis based on Wasserstein distance, which contextualizes this numeric difference by examining how model weights change over the course of training.
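As a rough sketch of what such a microbenchmark can look like (assuming NumPy, and not the paper's actual code), the example below computes a float64 "golden" attention output, then compares it against a naive baseline that materializes the full score matrix and against a Flash-Attention-style tiled pass with an online softmax. NumPy has no bfloat16 type, so the hypothetical helper `to_bf16` emulates BF16 by rounding float32 values to BF16's 8 significant bits. Where rounding is applied, the block size of 32, and the float32 accumulators in the tiled version are all assumptions made for illustration.

```python
import numpy as np

def to_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even); result is stored as float32."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & np.uint32(1)  # lowest mantissa bit that bfloat16 keeps
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

def attention_reference(q, k, v):
    """Standard softmax attention computed in float64 as the 'golden' value."""
    q, k, v = (np.asarray(t, dtype=np.float64) for t in (q, k, v))
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def attention_baseline(q, k, v, rnd):
    """Materialize the full score matrix; round every intermediate with `rnd`."""
    q, k, v = rnd(q), rnd(k), rnd(v)
    scale = np.float32(1.0 / np.sqrt(q.shape[1]))
    s = rnd(q @ k.T * scale)
    p = rnd(np.exp(s - s.max(axis=1, keepdims=True)))
    p = rnd(p / p.sum(axis=1, keepdims=True))
    return rnd(p @ v)

def attention_tiled(q, k, v, rnd, block=32):
    """Flash-Attention-style pass: iterate over key/value blocks with an online softmax."""
    q, k, v = rnd(q), rnd(k), rnd(v)
    n = q.shape[0]
    scale = np.float32(1.0 / np.sqrt(q.shape[1]))
    m = np.full((n, 1), -np.inf, dtype=np.float32)   # running row maximum
    l = np.zeros((n, 1), dtype=np.float32)           # running softmax denominator (float32)
    o = np.zeros((n, v.shape[1]), dtype=np.float32)  # unnormalized output accumulator (float32)
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = rnd(q @ kb.T * scale)
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        alpha = np.exp(m - m_new)                    # rescales what earlier blocks accumulated
        p = rnd(np.exp(s - m_new))
        l = alpha * l + p.sum(axis=1, keepdims=True)
        o = alpha * o + p @ vb
        m = m_new
    return rnd(o / l)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)).astype(np.float32) for _ in range(3))
ref = attention_reference(q, k, v)
for label, rnd in [("fp32", lambda t: t), ("bf16", to_bf16)]:
    for name, fn in [("baseline", attention_baseline), ("tiled", attention_tiled)]:
        err = np.abs(fn(q, k, v, rnd) - ref).max()
        print(f"{label:5s} {name:8s} max |out - fp64 reference| = {err:.3e}")
```

Swapping `rnd` lets the same harness test other precisions. Because both paths share the rounding function, any difference between their deviations comes from how each algorithm structures and rounds its intermediate steps.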

Applying this framework to the widely adopted Flash Attention optimization, the researchers found that Flash Attention shows roughly an order of magnitude more numeric deviation than Baseline Attention at BF16 when measured during an isolated forward pass. Further, the data-driven analysis based on Wasserstein distance provides upper bounds on how this numeric deviation affects model weights during training, finding that the numeric deviation introduced by Flash Attention is 2-5 times less significant than that introduced by low-precision training.
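To make the weight-level comparison concrete, the sketch below computes a 1D Wasserstein-1 distance between the value distributions of two same-size weight tensors and expresses one run's drift in units of another's. This is an assumption about how such a comparison could be set up, not the paper's procedure: the weights are synthetic, `relative_deviation` is a hypothetical helper, and the "candidate" and "proxy" roles only loosely mirror comparing Flash Attention against low-precision training. For tensors with different element counts, `scipy.stats.wasserstein_distance` handles the general empirical case.

```python
import numpy as np

def wasserstein_1d(a, b):
    """W1 distance between the empirical distributions of the values in two same-size tensors."""
    a, b = np.ravel(a).astype(np.float64), np.ravel(b).astype(np.float64)
    if a.size != b.size:
        raise ValueError("this sketch assumes tensors with the same number of elements")
    # In 1D with equally weighted samples of equal count, the optimal transport plan
    # pairs the i-th smallest value of `a` with the i-th smallest value of `b`.
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))

def relative_deviation(reference, candidate, proxy):
    """How far `candidate` drifts from `reference`, in units of `proxy`'s drift (hypothetical)."""
    return wasserstein_1d(reference, candidate) / wasserstein_1d(reference, proxy)

# Synthetic stand-ins for one weight matrix after training (illustrative values, not from the paper).
rng = np.random.default_rng(0)
w_reference = rng.normal(0.0, 0.02, size=(256, 256))
w_candidate = w_reference * 1.001  # plays the role of a run using the optimization under test
w_proxy = w_reference * 1.005      # plays the role of a run whose deviation is already accepted
print(relative_deviation(w_reference, w_candidate, w_proxy))
```

Because uniformly rescaling a tensor just stretches its sorted values, this example prints a value very close to 0.2 (that is, 0.001 / 0.005): the candidate's drift is about one fifth of the proxy's.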

In conclusion, the investigation underscores the importance of developing a principled approach to not only quantify but also contextualize the impact of training optimizations on numeric deviation. By constructing proxies that put this numeric deviation in context, the paper aims to reason about the likelihood of downstream model effects (i.e., training instability) that are traditionally difficult to measure.

Paper: https://lnkd.in/gpSdZu99